# [Benchmark]: Add model_config sweep mode and model registry #1180

noemotiovon wants to merge 2 commits into
Conversation
- Add Qwen 2.5 models (7B / 14B / 72B) and DeepSeek models (V2 Lite / V3) to `MODEL_REGISTRY`
- Add model_config sweep support to all 33 benchmark scripts, enabling benchmarks to sweep across different model architectures at a fixed sequence length
- Refactor benchmark scripts by extracting helper functions (`_setup_*`, `_resolve_model_config_*`) to improve code reuse and keep implementations cleaner across sweep modes
- Add grouped bar chart visualization in `benchmarks_visualizer` for model_config sweep results
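For context, registry entries pair a model name with its architecture config. A minimal sketch of what such a registry can look like (the `ModelConfig` fields and the exact shape of the repo's registry are assumptions, not the PR's actual definitions; the Qwen 2.5 7B and Llama 3 8B numbers come from the models' public configs):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ModelConfig:
    # Illustrative fields only; the real ModelConfig in the repo may differ.
    name: str
    hidden_size: int
    intermediate_size: int
    num_attention_heads: int
    vocab_size: int

# Hypothetical registry entries keyed by the CLI model name.
MODEL_REGISTRY = {
    "qwen2_5_7b": ModelConfig("qwen2_5_7b", 3584, 18944, 28, 152064),
    "llama_3_8b": ModelConfig("llama_3_8b", 4096, 14336, 32, 128256),
}
```

Keeping configs in a flat name-keyed dict lets every benchmark script sweep `MODEL_REGISTRY.values()` without model-specific branching.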
## Benchmark Framework Design

This document describes the overall design of the Liger-Kernel benchmark suite, including its two benchmark dimensions, the shared infrastructure, and the phased implementation plan.

### 1. Benchmark Dimensions

Every operator should ideally be benchmarked along two orthogonal dimensions:

**D1: Non-model dimension sweep (implemented).** Sweep non-model dimensions (e.g. sequence length, BT) with a fixed model config selected via `--model`.

**D2: Model dimension sweep (implemented).** Sweep model architecture dimensions (e.g. hidden_size, or discrete model configs from `MODEL_REGISTRY`) at a fixed token length.

### 2. D2 Design Choices

Following the maintainer discussion, we evaluated three approaches:

**Decision:** C as the primary approach, with A as optional enrichment for ops where single-parameter scaling is important.

**Rationale:**

### 3. Universal Token Length for D2

For D2 benchmarks, we need a fixed token length that is safe (no OOM) across all model configs and all operators.

**Strategy**
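The search can be sketched without any GPU code. Below is a toy, torch-free version of the strategy (function and variable names are hypothetical, and the probe here is a stand-in returning simulated peak bytes; a real probe would run the kernel and read `torch.cuda.max_memory_allocated()`):

```python
def pick_safe_bt(model_names, probe_fn, candidate_bts=(8192, 4096, 2048, 1024),
                 budget_bytes=24e9, memory_utilization=0.4):
    """Return the largest candidate BT whose worst-case peak fits the budget.

    probe_fn(name, bt) must return the measured peak memory in bytes for one
    model config at the given token count.
    """
    usable = budget_bytes * memory_utilization
    for bt in candidate_bts:  # try the largest BT first
        worst = max(probe_fn(name, bt) for name in model_names)
        if worst <= usable:
            return bt
    raise RuntimeError("no candidate BT fits the memory budget")

# Toy probe: pretend peak memory scales linearly with BT and hidden size
# (bf16 bytes times an arbitrary activation fudge factor).
sizes = {"llama_3_8b": 4096, "qwen2_5_72b": 8192}
peak = lambda name, bt: bt * sizes[name] * 2 * 200

print(pick_safe_bt(sizes, peak))
```

The key property is that the chosen BT is conservative across *all* registered configs, so one D2 sweep run cannot OOM on the largest model.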
**Proposed CLI**

```bash
# D1 (existing): token-length sweep with fixed model
python benchmark_geglu.py --model llama_3_8b

# D2 (new): model-config sweep with fixed token length
python benchmark_geglu.py --sweep-mode model_config --bt 2048
```

### 4. Infrastructure Changes

#### 4.1 New config type

```python
@dataclass(frozen=True)
class ModelConfigSweepConfig:
    """Config for D2 benchmarks that sweep across model configs."""

    model_configs: List[ModelConfig]  # models to benchmark
    bt: int  # fixed batch * seq_len
    batch_size: int  # safe batch size
    seq_len: int  # safe seq_len
```

#### 4.2 New helper

```python
def compute_model_config_sweep_config(
    model_configs: List[ModelConfig],
    probe_fn_factory: Callable[[ModelConfig, int], Callable[[], torch.Tensor]],
    bt: int = 2048,
    memory_utilization: float = 0.4,
) -> ModelConfigSweepConfig:
    """Find safe (batch_size, seq_len) that works across all model configs.

    For each model config, runs probe_fn_factory(model_config, bt) to measure
    peak memory, then picks the most conservative batch_size / seq_len.
    """
    ...
```

#### 4.3 Script-level changes

Each benchmark script gains a model-config sweep code path gated by `args.sweep_mode`:

```python
if args.sweep_mode == "model_config":
    configs = [MODEL_REGISTRY[name] for name in MODEL_REGISTRY]
    sweep = compute_model_config_sweep_config(configs, probe_fn_factory=..., bt=args.bt)
    # x_values = model config indices
    # extra_benchmark_configs = contains all model configs
    ...
else:
    # existing token-length sweep logic
    ...
```

#### 4.4 Visualization

D2 results produce grouped bar charts (speedup or throughput) rather than line charts:
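Such a chart can be produced with matplotlib along these lines (a sketch with made-up speedup numbers, not the actual `benchmarks_visualizer` code; one bar group per model config, one bar per kernel provider):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical D2 results: speedup per provider for each model config.
models = ["llama_3_8b", "qwen2_5_7b", "deepseek_v2_lite"]
speedup = {"liger": [1.8, 1.6, 1.7], "huggingface": [1.0, 1.0, 1.0]}

x = np.arange(len(models))      # one group per model config
width = 0.8 / len(speedup)      # split the group width across providers
fig, ax = plt.subplots()
for i, (provider, vals) in enumerate(speedup.items()):
    ax.bar(x + i * width, vals, width, label=provider)
ax.set_xticks(x + width / 2)
ax.set_xticklabels(models, rotation=20)
ax.set_ylabel("speedup vs baseline")
ax.legend()
fig.savefig("model_config_sweep.png")
```

Bar charts fit D2 because the x-axis is categorical (discrete model configs), so there is no meaningful line to interpolate between points.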
### 5. Phased Implementation Plan

**Phase 1: Foundation (current PR)**
Status: complete

**Phase 2: Model-config sweep (D2)**
Status: complete

**Phase 3: Rollout and visualization**
Status: in progress
### Phase 3 Kernel Rollout Tracking

Already refactored (D1 + D2):

- Norm-like kernels (input: BT × hidden_size):
- Loss kernels (input involves vocab_size or similar):
- RLHF/alignment loss kernels:
- Positional encoding kernels:
- Activation / misc kernels:
- Attention kernels:
- Other:
### 6. Directory Structure
```python
        This module re-computes forward in the backward, so forward occurs twice per iteration.
        """
```

> maybe we could keep these comments
```python
def __init__(self, mhc_cls, *, hidden_size, hc, num_heads, intermediate_mult, tmax, dtype, device):
    ...

def __init__(
    self, mhc_cls, *, vocab_size, hidden_size, hc, num_layers, num_heads, intermediate_mult, tmax, dtype, device
):
    ...

def _build_model(provider, *, hidden_size, hc, num_layers, num_heads, intermediate_mult, vocab_size, tmax, dtype):
    ...
```
```python
    Uses the DeepSpeed TiledMLP algorithm for memory-efficient MLP computation.
    """

def __init__(self, config, num_shards=None):

        }
    ],
    "overwrite": args.overwrite,
}
```

> We have built a general class BenchMiniMHCLM to test in this benchmark
```python
    groups=groups,
    bias=bias,
    dtype=dtype,
    device=device,
)
```

```python
        "bias": True,
        "dtype": torch.bfloat16,
    },
],
```

> we have dropped too many extra configs here

```python
extra_benchmark_configs=[
    {"M": 2048, "dtype": torch.float32},
    {"M": 2048, "dtype": torch.bfloat16},
],
```
```python
if args.sweep_mode == "model_config":
    all_model_configs = list(MODEL_REGISTRY.values())
    T = 512
    BT = 2048
```

> BT is too small compared to the current one
| {"B": 32, "T": 512, "D": 768, "dtype": torch.float32}, | ||
| # Llama | ||
| {"B": 8, "T": 2048, "D": 4096, "dtype": torch.float32}, | ||
| ], |
There was a problem hiding this comment.
here we already have a bert-like model and a llama-like model
| else: | ||
| model = get_benchmark_model_config(args.model) | ||
| T = 512 | ||
| probe_bt = 2048 |
```python
dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like(k, device=device)

dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like(k, device=device, dtype=dtype)

ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, grad_to_none=[q, k], rep=400, quantiles=QUANTILES)

"x_name": "T",
"x_label": "sequence length",
"x_values": [2**i for i in range(10, int(math.log2(max(1024, config.seq_len))) + 1)],
"kernel_providers": ["liger", "huggingface"],

q = torch.randn((1, seq_len, num_q_heads, head_dim), device=device, requires_grad=True, dtype=dtype)
k = torch.randn((1, seq_len, num_kv_heads, head_dim), device=device, requires_grad=True, dtype=dtype)

def __init__(self, H, V, dtype, use_bias=False, use_ref_bias=False, ignore_index=-100, beta=0.1):
    ...

self.KTO_loss = HFKTOLoss(ignore_index=ignore_index, beta=beta, use_ref_model=True).get_batch_loss_metrics

ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)

ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES)
```
---

## …uadratic scaling support (linkedin#1218)

### Summary

Refs linkedin#1200. Addresses non-linear memory scaling in benchmark sweep config inference.

The existing `compute_seq_len_sweep_config` inverts memory via `max_tokens = usable_bytes / kernel_bytes_per_token`, which only holds for linear-scaling kernels. For O(L²) kernels (e.g. `benchmark_sparse_multi_token_attention.py`), this overestimates capacity by orders of magnitude — the existing workaround there divides by `probe_L * probe_L`, but the downstream sweep math still treats the result as linear bytes-per-token.

Per discussion on the issue (linkedin#1200 (comment)), this PR adds a new helper rather than threading `scaling_method` through the existing function — 16+ benchmark scripts call `estimate_kernel_peak_memory` today, and a wider signature change would conflict with in-flight benchmark refactors (linkedin#1199, linkedin#1180). Linear-scaling callers are unchanged; only quadratic-scaling benchmarks opt in.

### What changed

- **`benchmark/scripts/benchmark_model_configs.py`** — adds `compute_seq_len_sweep_config_with_probe(model_cfg, probe_fn, probe_seq_len, probe_batch_size=1, scaling_method="linear" | "quadratic", ...)`. Internalizes the probe call + inversion; reuses `estimate_kernel_peak_memory` for the measurement.
- **`benchmark/scripts/benchmark_sparse_multi_token_attention.py`** — switches the `token_length` sweep mode to the new helper with `scaling_method="quadratic"`, dropping the manual `peak_bytes // (probe_L * probe_L)` workaround.

`estimate_kernel_peak_memory` and `compute_seq_len_sweep_config` are untouched.

### Validation

Hardware: A10G 24GB (g5.xlarge).

Synthetic O(L²) probe (B=2, L=2048, allocates `B * L * L` floats) using `LLAMA_3_8B` config and `max_seq_len=2**20` to bypass the model cap so the raw inversion is visible:

```
quadratic: SeqLenSweepConfig(batch_size=2, seq_len=8192)
linear:    SeqLenSweepConfig(batch_size=2, seq_len=65536)
```

The 8× gap (≈17× before snap-to-power-of-2) demonstrates the inversion difference: `linear` claims a sweep at L=65536 fits, when in reality L² at that size would require multiple TBs. `quadratic` lands at a realistic L=8192. This matches the issue's premise — for non-linear-scaling kernels, the existing inversion overestimates capacity and would OOM at the predicted boundary.

### Testing Done

- [x] Synthetic O(L²) sanity check on A10G — confirms `quadratic` predicts L=8192 vs `linear` predicts L=65536 for the same probe (8× separation, scales as expected).
- [x] `benchmark_sparse_multi_token_attention.py` imports + helper resolution verified locally.
- [ ] Full sparse-attention end-to-end sweep on A10G (deferred — synthetic test already isolates the inversion math from kernel-specific noise).

cc @Tcc0403
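The linear-vs-quadratic inversion gap reduces to simple arithmetic. The following toy reconstruction uses hypothetical helper names and an assumed 24 GB card with 40% utilization, not the PR's exact code or budget; under these numbers the pre-snap gap works out to ≈17×, the same ratio reported above:

```python
import math

def invert_linear(usable_bytes, peak_bytes, probe_tokens):
    """Linear model: peak = c * L, so max L = usable / (peak / probe_L)."""
    bytes_per_token = peak_bytes / probe_tokens
    return int(usable_bytes / bytes_per_token)

def invert_quadratic(usable_bytes, peak_bytes, probe_tokens):
    """Quadratic model: peak = c * L^2, so max L = probe_L * sqrt(usable / peak)."""
    return int(probe_tokens * math.sqrt(usable_bytes / peak_bytes))

def snap_pow2(n):
    """Round down to the nearest power of two, as sweep configs do."""
    return 2 ** int(math.log2(n))

# Toy numbers mirroring the synthetic probe: B=2, L=2048, fp32 B*L*L floats.
probe_L = 2048
peak = 2 * probe_L * probe_L * 4   # probe's measured peak bytes
usable = 24e9 * 0.4                # assumed 24 GB budget at 40% utilization

print(snap_pow2(invert_linear(usable, peak, probe_L)))     # wildly optimistic
print(snap_pow2(invert_quadratic(usable, peak, probe_L)))  # realistic
```

Note the general relationship: the linear inversion overestimates the quadratic one by exactly `sqrt(usable / peak)`, so the gap grows with headroom — the more memory appears free at probe time, the further past the true OOM boundary the linear estimate lands.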
**Hardware Type:** Atlas 800I A2

- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence